import torch
import torch.nn as nn
from knowledge_tracing.args import ARGS
from datasets.dataset_parser import Constants
from knowledge_tracing.network.util_network import get_constraint_losses
from knowledge_tracing.network.util_network import get_laplacian_loss, get_question_similarity_matrix


class DKT(nn.Module):
    """
    LSTM-based knowledge tracing model.

    Interaction indices are shifted right by one step, embedded, run through
    an LSTM and decoded by ``Linear + Sigmoid`` into one score per item index
    (``question_num + 1`` outputs). The score of the item answered at each
    step is then gathered.

    Beyond the original sequence (key ``'ori'``), augmented versions of the
    batch are also run in training mode and fed to ``get_constraint_losses``.
    With ``use_laplacian=True`` (qDKT), a Laplacian regularization term is
    also computed on the original sequence's outputs.
    """

    def __init__(self, device, encoder_features, hidden_dim, num_layers=1, dropout=0.0, use_laplacian=False):
        """
        Args:
            device: device the question similarity matrix is moved to (qDKT);
                also stored as ``self.device``.
            encoder_features: sequence of ``(feature, dim)`` pairs. Each
                ``feature`` must expose ``name`` and ``embed_layer(dim)``. The
                LSTM input size is taken from the FIRST pair's ``dim``.
            hidden_dim: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            dropout: dropout applied between stacked LSTM layers.
            use_laplacian: enable qDKT Laplacian regularization.
        """
        super().__init__()

        self.device = device
        self.encoder_features = encoder_features  # 'interaction'

        # One embedding layer per encoder feature, keyed by the feature's name.
        self.encoder_embedding_layers = nn.ModuleDict({
            feature.name: feature.embed_layer(dim)
            for feature, dim in encoder_features
        })
        self._embed_dim = encoder_features[0][1]

        self._question_num = Constants(ARGS.dataset_name, ARGS.data_root).NUM_ITEMS

        self._hidden_dim = hidden_dim
        self._num_layers = num_layers
        # nn.LSTM applies dropout only between stacked layers. With a single
        # layer the value has no effect and only triggers a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self._lstm = nn.LSTM(self._embed_dim, hidden_dim, num_layers=num_layers,
                             batch_first=True, dropout=lstm_dropout)
        # One extra output so item index 0 (what _get_item_idx maps
        # interaction 0 to) is a valid column.
        # NOTE(review): presumably item indices lie in 0..question_num; confirm.
        self._decoder = nn.Sequential(
            nn.Linear(hidden_dim, self._question_num + 1),
            nn.Sigmoid(),
        )

        # Xavier initialization for every parameter with dim > 1, which
        # includes the embedding tables.
        # NOTE(review): if an embedding layer was built with padding_idx, this
        # also overwrites its zero padding row. Confirm that is intended.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

        self.augmentations = ARGS.augmentations
        self.use_laplacian = use_laplacian  # flag for qDKT
        if use_laplacian:
            self.question_similarity_matrix = get_question_similarity_matrix().to(device)

    def _init_hidden(self, batch_size):
        """
        Return a zero ``(h_0, c_0)`` pair for the LSTM.

        Both tensors have shape ``(num_layers, batch_size, hidden_dim)``. They
        take the dtype and device of the model's first parameter and do not
        require grad.

        Args:
            batch_size: single integer.
        """
        weight = next(self.parameters())
        return (weight.new_zeros(self._num_layers, batch_size, self._hidden_dim),
                weight.new_zeros(self._num_layers, batch_size, self._hidden_dim))

    @staticmethod
    def _get_item_idx(interaction_idx):
        """
        Map interaction indices to item indices as ``(i + 1) // 2``.

        Interaction 0 maps to item 0, and interactions ``2q - 1`` and ``2q``
        both map to item ``q``. The result is cast to long and given a
        trailing singleton dim, so it can be used directly as a ``gather``
        index over the decoder output.
        """
        item_idx = (interaction_idx + 1) // 2
        item_idx = item_idx.long().unsqueeze(-1)
        return item_idx

    def _shift_tensor(self, x):
        """
        Shift a tensor of shape ``(bsz, seq_len)`` one step right along the
        sequence dimension.

        A zero column is prepended and the last column is dropped, so
        ``result[:, t] == x[:, t - 1]`` for ``t >= 1`` and ``result[:, 0] == 0``.
        The zero padding takes ``x``'s dtype and device, so it does not depend
        on ``self.device`` matching where the batch lives.
        """
        padding = x.new_zeros((x.shape[0], 1))
        return torch.cat((padding, x[:, :-1]), dim=-1)

    def forward(self, data):
        """
        Args:
            data: A dictionary of dictionary of tensors. Keys ('ori', 'rep',
                'ins', 'del') say whether the data is the original or an
                augmented version. Each inner dict needs 'interaction_idx'.
                The 'ori' entry also needs 'loss_mask' when qDKT is active in
                training mode.

        Returns:
            A tuple ``(last_output, aug_losses)``.

            ``last_output`` maps each processed key to the gathered score of
            the item at each step, with shape ``(bsz, seq_len, 1)``. In eval
            mode only 'ori' is processed.

            ``aug_losses`` holds the training-mode constraint losses plus
            'lap' for qDKT, or is None when there are no auxiliary losses.
        """
        batch_size = data['ori']['interaction_idx'].shape[0]
        # Fresh zero state. It never requires grad, so no detach is needed.
        hidden = self._init_hidden(batch_size)
        embed_layer = self.encoder_embedding_layers['interaction_idx']

        last_output = {}
        aug_losses = {}
        laplacian_loss = None

        for aug, aug_data in data.items():
            # In eval mode only the original sequence is run.
            if aug != 'ori' and not self.training:
                continue

            interaction_idx = aug_data['interaction_idx']
            # The input at step t is interaction t - 1, so the (unidirectional)
            # LSTM output at step t has not seen interaction t itself.
            enc_embed = embed_layer(self._shift_tensor(interaction_idx))

            lstm_output, _ = self._lstm(enc_embed, hidden)  # (bsz, seq_len, hidden_dim)
            scores = self._decoder(lstm_output)  # (bsz, seq_len, question_num + 1)
            # Keep only the score of the item involved in each step's interaction.
            last_output[aug] = scores.gather(-1, self._get_item_idx(interaction_idx))

            if aug == 'ori' and self.training and self.use_laplacian:
                # qDKT: laplacian regularization
                laplacian_loss = get_laplacian_loss(scores,
                                                    aug_data['loss_mask'],
                                                    self.question_similarity_matrix)

        # constraint losses
        if self.training:
            aug_losses = get_constraint_losses(data, last_output)

        if laplacian_loss is not None:
            aug_losses['lap'] = laplacian_loss

        # An empty loss dict is reported as None.
        return last_output, aug_losses or None
